{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "# MNIST Image Classification with TensorFlow\n",
    "\n",
    "This notebook demonstrates how to implement a simple linear image model on MNIST using tf.keras.\n",
    "<hr/>\n",
    "This <a href=\"mnist_models.ipynb\">companion notebook</a> extends the basic harness of this notebook to a variety of models including DNN, CNN, dropout, pooling etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
   "!sudo chown -R jupyter:jupyter /home/jupyter/training-data-analyst"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/jupyter/.local/lib/python3.5/site-packages/tensorflow_core/python/compat/v2_compat.py:88: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "non-resource variables are not supported in the long term\n",
      "2.1.0\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import shutil\n",
    "import os\n",
    "import tensorflow.compat.v1 as tf\n",
    "tf.disable_v2_behavior()\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "## Exploring the data\n",
    "\n",
    "Let's download MNIST data and examine the shape. We will need these numbers ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "HEIGHT = 28\n",
    "WIDTH = 28\n",
    "NCLASSES = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train.shape = (60000, 28, 28)\n",
      "y_train.shape = (60000, 10)\n",
      "x_test.shape = (10000, 28, 28)\n",
      "y_test.shape = (10000, 10)\n"
     ]
    }
   ],
   "source": [
    "# Get mnist data\n",
    "mnist = tf.keras.datasets.mnist\n",
    "\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "\n",
    "# Scale our features between 0 and 1\n",
    "x_train, x_test = x_train / 255.0, x_test / 255.0 \n",
    "\n",
    "# Convert labels to categorical one-hot encoding\n",
    "y_train = tf.keras.utils.to_categorical(y = y_train, num_classes = NCLASSES)\n",
    "y_test = tf.keras.utils.to_categorical(y = y_test, num_classes = NCLASSES)\n",
    "\n",
    "print(\"x_train.shape = {}\".format(x_train.shape))\n",
    "print(\"y_train.shape = {}\".format(y_train.shape))\n",
    "print(\"x_test.shape = {}\".format(x_test.shape))\n",
    "print(\"y_test.shape = {}\".format(y_test.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADg5JREFUeJzt3XuMXPV5xvHnwawNmEtsaDYuuDFNKKlDywIb0whaSEgiYiUFqhZhqanT0jhSAyoVSYNADfxRKahtEkhLUU1wYyIuScrFboVSqGuJRCEuCzg2xlButrBlbBLT2EnA2N63f+xxtIGd36zndmZ5vx9pNTPnPWfOqyM/PjPzOzM/R4QA5HNI3Q0AqAfhB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+Q1KG93Nl0z4jDNLOXuwRSeU0/0+uxx5NZt63w2z5f0o2Spkn6WkRcX1r/MM3UmT6vnV0CKFgTqya9bssv+21Pk3STpI9Kmi9pke35rT4fgN5q5z3/AknPRsTzEfG6pLskXdCZtgB0WzvhP17Si+Meb6mW/RLbS2yP2B7Zqz1t7A5AJ3X90/6IWBoRwxExPKAZ3d4dgElqJ/xbJc0d9/iEahmAKaCd8D8i6STbJ9qeLukSSSs70xaAbmt5qC8i9tm+TNJ/amyob1lEbOhYZwC6qq1x/oi4X9L9HeoFQA9xeS+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJtTVLr+1NknZL2i9pX0QMd6IpAN3XVvgrH4iIH3XgeQD0EC/7gaTaDX9IesD2o7aXdKIhAL3R7sv+syNiq+23S3rQ9lMR8dD4Far/FJZI0mE6os3dAeiUts78EbG1ut0h6V5JCyZYZ2lEDEfE8IBmtLM7AB3Ucvhtz7R91IH7kj4i6YlONQagu9p52T8o6V7bB57njoj4Tke6AtB1LYc/Ip6XdGoHewHQQwz1AUkRfiApwg8kRfiBpAg/kBThB5LqxLf60Mf2n3t6sX7oF7YX6/9+8spifcDTivW9sb9h7ay1lxS3PfaagWLdm7YW6z/++PyGtdn3la9HG929u1h/K+DMDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJMc4/BXhG+ReQdv/+UMPatV9cVtz2nMN/XqyPFqvS3ijXRwvP8N2hO4rbnv43nyzWT31H+dy1Yt4/Nay9722XF7cd/MfvF+tvBZz5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApxvmngD3n/lax/t83NB7Pbmb1q0cW61/42z8r1gd+3mSgv2DXO8vnnunlSxD0158tX8Pwk9F9DWtHbmv8OwNZcOYHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaSajvPbXibpY5J2RMQp1bLZkr4paZ6kTZIujohXutfmW1u8vzzT+Rdv/peWn3vRcwuL9V3Xzi3WZ61+uOV9N3PMu08s1oe+/Vyx/pvTy+eu96z4q4a13/i3NcVtM5jMmf/rks5/w7KrJK2KiJMkraoeA5hCmoY/Ih6StPMNiy+QtLy6v1zShR3uC0CXtfqefzAitlX3X5I02KF+APRI2x/4RURIaniBt+0ltkdsj+zVnnZ3B6BDWg3/dttzJKm63dFoxYhYGhHDETE8oPIPUQLonVbDv1LS4ur+YkkrOtMOgF5pGn7bd0p6WNLJtrfYvlTS9ZI+bPsZSR+qHgOYQpqO80fEogal8zrcS1qvXPNqsX5Gk3dLC5/6g4a1aZ89urjttMcfKz95F/3fGeXPia99+7faev65D7S1+VseV/gBSRF+ICnCDyRF+IGkCD+QFOEHkuKnu3vghbt+u1jfcNq/Futb9pWHAg+5ZlbDWjy+rrhtt5Wm
F3/3FU8Wtz2kybnpTzeXR5sPv+9/ivXsOPMDSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKM8/fAn8wvjzeParRY37yv/LVc/aC+sfzSOL4kPX1D458lX/FrNxW3LR8VafPfn1ysHyF+nruEMz+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJMU4P4qmvbc8lr7x8mOK9ac+Xh7LL1n96pHF+lHff6FY39/ynnPgzA8kRfiBpAg/kBThB5Ii/EBShB9IivADSTUd57e9TNLHJO2IiFOqZddJ+pSkl6vVro6I+7vV5FR39wtDxfrnjl1frJ8242fF+u+ue+2ge5qsBUfcU6x/4PDyvpt9J7/kyh/+YbF+wvYNbTw7JnPm/7qk8ydY/pWIGKr+CD4wxTQNf0Q8JGlnD3oB0EPtvOe/zPY628tsN54vCkBfajX8N0t6l6QhSdskfanRiraX2B6xPbJXe1rcHYBOayn8EbE9IvZHxKikWyQtKKy7NCKGI2J4QOUfewTQOy2F3/accQ8vkvREZ9oB0CuTGeq7U9K5ko6zvUXStZLOtT0kKSRtkvTpLvYIoAscET3b2dGeHWe6PKf6W9EhRx1VrI/eV/5O/H+8Z0V5+7ZG09tzzucvL9ZHF/24Ye27Q3cUtz3/0r8o1qd/55FiPaM1sUq7Yqcnsy5X+AFJEX4gKcIPJEX4gaQIP5AU4QeS4qe7e2B09+7yCueV6x+8qDzkteOM1v8Pn7WxPNR7zO0/KNZf/kb5ku2nhu5qWLv1J/OK2x6xYVuxvq9YRTOc+YGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcb5p4Aj7l1TrM+7t0eNTOCpD36tWC993fimp88pbvurLz7ZUk+YHM78QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4/womvbek5us8Wixunnf6w1rg189rIWO0Cmc+YGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gqabj/LbnSrpN0qCkkLQ0Im60PVvSNyXNk7RJ0sUR8Ur3WkUdnr92elvb/9Hjf96w9o7Vj7X13GjPZM78+yRdGRHzJf2OpM/Yni/pKkmrIuIkSauqxwCmiKbhj4htEfFYdX+3pI2Sjpd0gaTl1WrLJV3YrSYBdN5Bvee3PU/SaZLWSBqMiAPzKb2ksbcFAKaISYff9pGS7pZ0RUTsGl+LiNDY5wETbbfE9ojtkb0qz+sGoHcmFX7bAxoL/u0RcU+1eLvtOVV9jqQdE20bEUsjYjgihgc0oxM9A+iApuG3bUm3StoYEV8eV1opaXF1f7GkFZ1vD0C3TOYrvWdJ+oSk9bbXVsuulnS9pG/ZvlTSZkkXd6dFdFO8/9RifeWZ/9zkGcpfy/WqWQfZEXqlafgj4nuS3KB8XmfbAdArXOEHJEX4gaQIP5AU4QeSIvxAUoQfSIqf7k5ux/tmFusnHloexy9NwS1Jh7424VXf6AOc+YGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcb5k3vtuPI4fLNx/Bt2zi/Wj73l4YPuCb3BmR9IivADSRF+ICnCDyRF+IGkCD+QFOEHkmKcP7k/vnB1W9svW/GhYn2eGOfvV5z5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCppuP8tudKuk3SoKSQtDQibrR9naRPSXq5WvXqiLi/W42iO+5+YahY/9yx63vUCXptMhf57JN0ZUQ8ZvsoSY/afrCqfSUi/qF77QHolqbhj4htkrZV93fb3ijp+G43BqC7Duo9v+15kk6TtKZadJntdbaX2Z7VYJsltkdsj+zVnraaBdA5kw6/7SMl3S3piojYJelmSe+SNKSxVwZfmmi7iFgaEcMRMTygGR1oGUAnTCr8tgc0FvzbI+IeSYqI7RGxPyJGJd0iaUH32gTQaU3Db9uSbpW0MSK+PG75nHGrXSTpic63B6BbJvNp/1mSPiFpve211bKrJS2yPaSx4b9Nkj7dlQ7RVbFqdrF+9QlnFuuDI/s72Q56
aDKf9n9PkicoMaYPTGFc4QckRfiBpAg/kBThB5Ii/EBShB9IyhHlKZo76WjPjjN9Xs/2B2SzJlZpV+ycaGj+TTjzA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBSPR3nt/2ypM3jFh0n6Uc9a+Dg9Gtv/dqXRG+t6mRv74yIX5nMij0N/5t2bo9ExHBtDRT0a2/92pdEb62qqzde9gNJEX4gqbrDv7Tm/Zf0a2/92pdEb62qpbda3/MDqE/dZ34ANakl/LbPt/207WdtX1VHD43Y3mR7ve21tkdq7mWZ7R22nxi3bLbtB20/U91OOE1aTb1dZ3trdezW2l5YU29zba+2/aTtDbb/slpe67Er9FXLcev5y37b0yT9r6QPS9oi6RFJiyLiyZ420oDtTZKGI6L2MWHbvyfpp5Jui4hTqmV/J2lnRFxf/cc5KyI+3ye9XSfpp3XP3FxNKDNn/MzSki6U9EnVeOwKfV2sGo5bHWf+BZKejYjnI+J1SXdJuqCGPvpeRDwkaecbFl8gaXl1f7nG/vH0XIPe+kJEbIuIx6r7uyUdmFm61mNX6KsWdYT/eEkvjnu8Rf015XdIesD2o7aX1N3MBAaradMl6SVJg3U2M4GmMzf30htmlu6bY9fKjNedxgd+b3Z2RJwu6aOSPlO9vO1LMfaerZ+GayY1c3OvTDCz9C/UeexanfG60+oI/1ZJc8c9PqFa1hciYmt1u0PSveq/2Ye3H5gktbrdUXM/v9BPMzdPNLO0+uDY9dOM13WE/xFJJ9k+0fZ0SZdIWllDH29ie2b1QYxsz5T0EfXf7MMrJS2u7i+WtKLGXn5Jv8zc3GhmadV87PpuxuuI6PmfpIUa+8T/OUnX1NFDg75+XdIPq78Ndfcm6U6NvQzcq7HPRi6VdKykVZKekfRfkmb3UW/fkLRe0jqNBW1OTb2drbGX9Oskra3+FtZ97Ap91XLcuMIPSIoP/ICkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJPX/rJw9J1q+cE8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "IMGNO = 12\n",
    "plt.imshow(x_test[IMGNO].reshape(HEIGHT, WIDTH));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "## Define the model.\n",
    "Let's start with a very simple linear classifier. All our models will have this basic interface -- they will take an image and return probabilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# Linear softmax classifier assembled with the Keras Sequential API\n",
    "def linear_model():\n",
    "    \"\"\"Build the linear image classifier.\n",
    "\n",
    "    Inputs of shape [HEIGHT, WIDTH] enter through the input layer named\n",
    "    \"image\", are flattened, and go through a single dense softmax layer\n",
    "    named \"probabilities\" with NCLASSES units.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled tf.keras Sequential model.\n",
    "    \"\"\"\n",
    "    layer_stack = [\n",
    "        tf.keras.layers.InputLayer(input_shape = [HEIGHT, WIDTH], name = \"image\"),\n",
    "        tf.keras.layers.Flatten(),\n",
    "        tf.keras.layers.Dense(units = NCLASSES, activation = tf.nn.softmax, name = \"probabilities\"),\n",
    "    ]\n",
    "    return tf.keras.models.Sequential(layer_stack)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "## Write Input Functions\n",
    "\n",
    "As usual, we need to specify input functions for training, evaluation, and prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# Training input function: shuffled batches of 100, with no epoch limit\n",
    "train_input_fn = tf.estimator.inputs.numpy_input_fn(\n",
    "    x = {\"image\": x_train},\n",
    "    y = y_train,\n",
    "    shuffle = True,\n",
    "    num_epochs = None,\n",
    "    batch_size = 100,\n",
    "    queue_capacity = 5000\n",
    "  )\n",
    "\n",
    "# Evaluation input function: a single unshuffled pass over the test set\n",
    "eval_input_fn = tf.estimator.inputs.numpy_input_fn(\n",
    "    x = {\"image\": x_test},\n",
    "    y = y_test,\n",
    "    shuffle = False,\n",
    "    num_epochs = 1,\n",
    "    batch_size = 100,\n",
    "    queue_capacity = 5000\n",
    "  )\n",
    "\n",
    "# Serving input function used when exporting the model for inference\n",
    "def serving_input_fn():\n",
    "    \"\"\"Describe the inputs the exported model receives at prediction time.\n",
    "\n",
    "    A float32 placeholder of shape [None, HEIGHT, WIDTH] is registered under\n",
    "    the key \"image\". The same dict serves as both the receiver tensors and\n",
    "    the features handed to the model, i.e. requests are passed through as-is.\n",
    "\n",
    "    Returns:\n",
    "        A tf.estimator.export.ServingInputReceiver.\n",
    "    \"\"\"\n",
    "    image_batch = tf.placeholder(dtype = tf.float32, shape = [None, HEIGHT, WIDTH])\n",
    "    inputs = {\"image\": image_batch}\n",
    "    return tf.estimator.export.ServingInputReceiver(features = inputs, receiver_tensors = inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "## Create train_and_evaluate function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "`tf.estimator.train_and_evaluate` supports distributed training; in this notebook it runs locally (non-distributed)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "def train_and_evaluate(output_dir, hparams):\n",
    "    \"\"\"Train the linear model as an Estimator, evaluating and exporting it.\n",
    "\n",
    "    Args:\n",
    "        output_dir: Directory passed to the Estimator as its model_dir.\n",
    "        hparams: Dict of hyperparameters. \"train_steps\" (required) is used\n",
    "            as the TrainSpec max_steps. \"learning_rate\" (optional, defaults\n",
    "            to 0.001) is the learning rate given to the Adam optimizer.\n",
    "    \"\"\"\n",
    "    # Build Keras model\n",
    "    model = linear_model()\n",
    "\n",
    "    # Build the optimizer explicitly so that hparams[\"learning_rate\"] is honored;\n",
    "    # the \"adam\" string shortcut would always use the optimizer's default rate\n",
    "    optimizer = tf.keras.optimizers.Adam(\n",
    "        learning_rate = hparams.get(\"learning_rate\", 0.001))\n",
    "\n",
    "    # Compile Keras model with optimizer, loss function, and eval metrics\n",
    "    model.compile(\n",
    "        optimizer = optimizer,\n",
    "        loss = \"categorical_crossentropy\",\n",
    "        metrics = [\"accuracy\"])\n",
    "\n",
    "    # Convert Keras model to an Estimator\n",
    "    estimator = tf.keras.estimator.model_to_estimator(\n",
    "        keras_model = model, \n",
    "        model_dir = output_dir)\n",
    "\n",
    "    # Set estimator's train_spec to use train_input_fn and train for so many steps\n",
    "    train_spec = tf.estimator.TrainSpec(\n",
    "        input_fn = train_input_fn,\n",
    "        max_steps = hparams[\"train_steps\"])\n",
    "\n",
    "    # Create exporter that uses serving_input_fn to create saved_model for serving\n",
    "    exporter = tf.estimator.LatestExporter(\n",
    "        name = \"exporter\", \n",
    "        serving_input_receiver_fn = serving_input_fn)\n",
    "\n",
    "    # Set estimator's eval_spec to use eval_input_fn and export saved_model\n",
    "    eval_spec = tf.estimator.EvalSpec(\n",
    "        input_fn = eval_input_fn,\n",
    "        steps = None,\n",
    "        exporters = exporter)\n",
    "\n",
    "    # Run train_and_evaluate loop\n",
    "    tf.estimator.train_and_evaluate(\n",
    "        estimator = estimator, \n",
    "        train_spec = train_spec, \n",
    "        eval_spec = eval_spec)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "This is the main() function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using default config.\n",
      "INFO:tensorflow:Using the Keras model provided.\n",
      "WARNING:tensorflow:From /home/jupyter/.local/lib/python3.5/site-packages/tensorflow_core/python/ops/init_ops.py:97: calling GlorotUniform.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "WARNING:tensorflow:From /home/jupyter/.local/lib/python3.5/site-packages/tensorflow_core/python/ops/init_ops.py:97: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "INFO:tensorflow:Using config: {'_eval_distribute': None, '_session_creation_timeout_secs': 7200, '_save_summary_steps': 100, '_global_id_in_cluster': 0, '_service': None, '_experimental_distribute': None, '_num_ps_replicas': 0, '_cluster_spec': ClusterSpec({}), '_train_distribute': None, '_tf_random_seed': None, '_experimental_max_worker_delay_secs': None, '_protocol': None, '_log_step_count_steps': 100, '_model_dir': 'mnist/learned', '_is_chief': True, '_evaluation_master': '', '_keep_checkpoint_max': 5, '_device_fn': None, '_keep_checkpoint_every_n_hours': 10000, '_session_config': allow_soft_placement: true\n",
      "graph_options {\n",
      "  rewrite_options {\n",
      "    meta_optimizer_iterations: ONE\n",
      "  }\n",
      "}\n",
      ", '_save_checkpoints_secs': 600, '_task_id': 0, '_save_checkpoints_steps': None, '_task_type': 'worker', '_master': '', '_num_worker_replicas': 1}\n",
      "INFO:tensorflow:Not using Distribute Coordinator.\n",
      "INFO:tensorflow:Running training and evaluation locally (non-distributed).\n",
      "INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.\n",
      "WARNING:tensorflow:From /home/jupyter/.local/lib/python3.5/site-packages/tensorflow_core/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.\n",
      "WARNING:tensorflow:From /home/jupyter/.local/lib/python3.5/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:62: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "To construct input pipelines, use the `tf.data` module.\n",
      "WARNING:tensorflow:From /home/jupyter/.local/lib/python3.5/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:500: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "To construct input pipelines, use the `tf.data` module.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='mnist/learned/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})\n",
      "INFO:tensorflow:Warm-starting from: mnist/learned/keras/keras_model.ckpt\n",
      "INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.\n",
      "INFO:tensorflow:Warm-started 2 variables.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "WARNING:tensorflow:From /home/jupyter/.local/lib/python3.5/site-packages/tensorflow_core/python/training/monitored_session.py:906: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "To construct input pipelines, use the `tf.data` module.\n",
      "INFO:tensorflow:Saving checkpoints for 0 into mnist/learned/model.ckpt.\n",
      "INFO:tensorflow:loss = 2.3702068, step = 1\n",
      "INFO:tensorflow:global_step/sec: 340.867\n",
      "INFO:tensorflow:loss = 0.7732622, step = 101 (0.295 sec)\n",
      "INFO:tensorflow:global_step/sec: 384.292\n",
      "INFO:tensorflow:loss = 0.6145685, step = 201 (0.260 sec)\n",
      "INFO:tensorflow:global_step/sec: 391.351\n",
      "INFO:tensorflow:loss = 0.50422275, step = 301 (0.256 sec)\n",
      "INFO:tensorflow:global_step/sec: 397.311\n",
      "INFO:tensorflow:loss = 0.4027025, step = 401 (0.252 sec)\n",
      "INFO:tensorflow:global_step/sec: 391.875\n",
      "INFO:tensorflow:loss = 0.44188076, step = 501 (0.258 sec)\n",
      "INFO:tensorflow:global_step/sec: 386.557\n",
      "INFO:tensorflow:loss = 0.37594798, step = 601 (0.256 sec)\n",
      "INFO:tensorflow:global_step/sec: 389.035\n",
      "INFO:tensorflow:loss = 0.34926274, step = 701 (0.257 sec)\n",
      "INFO:tensorflow:global_step/sec: 400.38\n",
      "INFO:tensorflow:loss = 0.52942824, step = 801 (0.250 sec)\n",
      "INFO:tensorflow:global_step/sec: 391.907\n",
      "INFO:tensorflow:loss = 0.40796912, step = 901 (0.255 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 1000 into mnist/learned/model.ckpt.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Starting evaluation at 2020-06-02T17:25:49Z\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from mnist/learned/model.ckpt-1000\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Inference Time : 0.65976s\n",
      "INFO:tensorflow:Finished evaluation at 2020-06-02-17:25:50\n",
      "INFO:tensorflow:Saving dict for global step 1000: acc = 0.9136, global_step = 1000, loss = 0.3200383\n",
      "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: mnist/learned/model.ckpt-1000\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "WARNING:tensorflow:From /home/jupyter/.local/lib/python3.5/site-packages/tensorflow_core/python/saved_model/signature_def_utils_impl.py:201: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.\n",
      "INFO:tensorflow:Signatures INCLUDED in export for Classify: None\n",
      "INFO:tensorflow:Signatures INCLUDED in export for Train: None\n",
      "INFO:tensorflow:Signatures INCLUDED in export for Regress: None\n",
      "INFO:tensorflow:Signatures INCLUDED in export for Eval: None\n",
      "INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default']\n",
      "INFO:tensorflow:Restoring parameters from mnist/learned/model.ckpt-1000\n",
      "INFO:tensorflow:Assets added to graph.\n",
      "INFO:tensorflow:No assets to write.\n",
      "INFO:tensorflow:SavedModel written to: mnist/learned/export/exporter/temp-b'1591118750'/saved_model.pb\n",
      "INFO:tensorflow:Loss for final step: 0.45906192.\n"
     ]
    }
   ],
   "source": [
    "OUTDIR = \"mnist/learned\"\n",
    "shutil.rmtree(OUTDIR, ignore_errors = True) # start fresh each time\n",
    "\n",
    "hparams = {\"train_steps\": 1000, \"learning_rate\": 0.01}\n",
    "train_and_evaluate(OUTDIR, hparams)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I got:\n",
    "\n",
    "`Saving dict for global step 1000: categorical_accuracy = 0.9112, global_step = 1000, loss = 0.32516304`\n",
    "\n",
    "In other words, we achieved 91.12% accuracy with the simple linear model!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "<pre>\n",
    "# Copyright 2020 Google Inc. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#      http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "</pre>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
